import math
import os
import pennylane as qml
from pennylane import numpy as np
from optimizer_apfa import gd_optimizer_apfa
# Sweep grid: every (qubit count, layer count) pair, each from 5 to 15 inclusive.
qubits_list = range(5, 15 + 1)
layers_list = range(5, 15 + 1)

# Run settings shared by every configuration in the sweep.
iteration = 100   # iteration budget handed to gd_optimizer_apfa
n_time = 20       # number of optimizer runs per (n_qubits, n_layers) pair
n_check = 20      # forwarded to gd_optimizer_apfa as n_check
lr = 0.01         # learning rate handed to gd_optimizer_apfa
c_const = 1.0     # EGI variance scale c, with gamma^2 = c / L


def enhanced_gaussian_init(n_params: int, n_layers: int, c: float = 1.0) -> np.ndarray:
    """Sample initial circuit parameters with Enhanced Gaussian Initialization (EGI).

    Each parameter is drawn independently as theta_i ~ N(0, gamma^2), where the
    variance shrinks with circuit depth: gamma^2 = c / L.

    Args:
        n_params: How many parameters to draw.
        n_layers: Circuit depth L; must be strictly positive.
        c: Variance scale; must satisfy 0 < c <= 1 so that gamma <= 1.

    Returns:
        The samples produced by ``np.random.normal(0.0, gamma, n_params)``.

    Raises:
        ValueError: If ``n_layers`` is not positive, or if ``c`` lies outside (0, 1].
    """
    # Validate depth first, then the scale (same order as the error contract).
    if n_layers <= 0:
        raise ValueError("n_layers must be positive.")
    if not 0 < c <= 1:
        raise ValueError("c_const must satisfy 0 < c ≤ 1 to ensure γ ≤ 1.")

    variance = c / n_layers
    std_dev = math.sqrt(variance)
    return np.random.normal(0.0, std_dev, n_params)


def build_variant_0(n_qubits: int, n_layers: int, dev):
    """Build the layered "variant 0" ansatz as a QNode bound to ``dev``.

    Each of the ``n_layers`` layers applies a ring of CZ entanglers, then one
    RX rotation per qubit, then one RY rotation per qubit. Rotation angles are
    ``2 * params[k]``, so the circuit reads ``2 * n_layers * n_qubits``
    consecutive entries of ``params``.

    The QNode returns a list of ``n_qubits - 1`` expectation values, one per
    nearest-neighbour pair ``(m, m + 1)``, of the operator
    ``X_m X_{m+1} + Y_m Y_{m+1} + Z_m Z_{m+1}``. This is an open chain: there
    is no (n-1, 0) wrap-around term, even though the CZ ring does wrap.

    Args:
        n_qubits: Number of wires used; must be at least 2.
        n_layers: Number of repeated layers.
        dev: PennyLane device the QNode is attached to.

    Returns:
        The QNode ``ansatz_variant_0(params)``.

    Raises:
        ValueError: If ``n_qubits < 2``, because a CZ entangler needs two
            distinct wires.
    """
    if n_qubits < 2:
        raise ValueError("build_variant_0 needs at least 2 qubits.")

    # Ring of CZ pairs (j, j+1 mod n), computed once rather than on every call.
    # With exactly two qubits the ring would apply CZ(0, 1) and then CZ(1, 0).
    # CZ is symmetric and self-inverse, so the pair would cancel to the identity
    # and the ansatz would never entangle. Apply a single CZ in that case.
    if n_qubits == 2:
        cz_pairs = [(0, 1)]
    else:
        cz_pairs = [(j, (j + 1) % n_qubits) for j in range(n_qubits)]

    @qml.qnode(dev)
    def ansatz_variant_0(params):
        k = 0  # running index into params, shared across all layers

        for _ in range(n_layers):
            for a, b in cz_pairs:
                qml.CZ(wires=[a, b])
            for j in range(n_qubits):
                qml.RX(2 * params[k], wires=j)
                k += 1
            for j in range(n_qubits):
                qml.RY(2 * params[k], wires=j)
                k += 1

        # m + 1 <= n_qubits - 1 here, so no modulo is needed.
        return [
            qml.expval(
                qml.PauliX(m) @ qml.PauliX(m + 1)
                + qml.PauliY(m) @ qml.PauliY(m + 1)
                + qml.PauliZ(m) @ qml.PauliZ(m + 1)
            )
            for m in range(n_qubits - 1)
        ]

    return ansatz_variant_0


# Sweep every (n_qubits, n_layers) configuration. For each one, run the APFA
# optimizer n_time times and save each run's outputs as .npy files under
# Dataset/q{n_qubits}_L{n_layers}/.
for n_qubits in qubits_list:
    for n_layers in layers_list:
        dev = qml.device("default.qubit", wires=n_qubits)

        # Noise scale handed to the optimizer: sqrt(1 / (8 L)).
        # NOTE(review): presumably a standard deviation; confirm against
        # gd_optimizer_apfa.
        noise_gamma = math.sqrt(1 / (8 * n_layers))

        ansatz = build_variant_0(n_qubits, n_layers, dev)
        # One RX and one RY angle per qubit per layer.
        n_params = 2 * n_layers * n_qubits

        # NOTE(review): drawn once per configuration, so every one of the
        # n_time runs starts from the same initial point (each gets its own
        # copy). Confirm this is intended rather than one fresh draw per run.
        init_weights = enhanced_gaussian_init(n_params, n_layers, c=c_const)

        # The output folder depends only on the configuration, so build the
        # path and create the directory once, not on every run.
        folder_tgt = f"Dataset/q{n_qubits}_L{n_layers}"
        os.makedirs(folder_tgt, exist_ok=True)

        for run_idx in range(n_time):
            loss_t, gradnorm_t, final_w_target, freeze_mask_list_t = gd_optimizer_apfa(
                ansatz=ansatz,
                weights=init_weights.copy(),  # copy so runs cannot share mutated state
                noise_gamma=noise_gamma,
                lr=lr,
                iteration=iteration,
                n_check=n_check,
                alpha=0.7,
                freeze_factor=0.2,
                activate_factor=0.4,
                freeze_count_th=3,
                activate_count_th=2,
                warmup_steps=20,
            )

            np.save(os.path.join(folder_tgt, f"loss_{run_idx}.npy"), loss_t)
            np.save(os.path.join(folder_tgt, f"grad_norm_{run_idx}.npy"), gradnorm_t)
            np.save(os.path.join(folder_tgt, f"weights_final_target_{run_idx}.npy"), final_w_target)
            np.save(os.path.join(folder_tgt, f"freeze_mask_{run_idx}.npy"), freeze_mask_list_t)
